# *- coding:utf-8 -*-
import re
import csv
import numpy as np
import time
import heapq
import random
"""
    1. 文件名配置
"""
SUBMIT = True
if SUBMIT:
    config_file = '/data/config.ini'
    demand_file = '/data/demand.csv'
    qos_file = '/data/qos.csv'
    site_bandwidth_file = '/data/site_bandwidth.csv'
    output_file = '/output/solution.txt'
else:
    config_file = 'data/config.ini'
    demand_file = 'data/demand.csv'
    qos_file = 'data/qos.csv'
    site_bandwidth_file = 'data/site_bandwidth.csv'
    output_file = 'output/solution.txt'
"""
    2. 模型参数
"""
class Config:
    def __init__(self):
        # 训练次数
        self.train_num = 1000
        # 轮次
        self.num_epoch = 1
        self.first_num_epoch = 10
        self.early_stop = 100

        # adam优化器参数
        self.lr = 0.05
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.eps = 1e-8
"""
    3. 工具包
"""

class My_PriorityQueue(object):
    def __init__(self):
        self._queue = []
        self._index = 0

    def push(self, item, priority):
        """
        队列由 (priority, index, item) 形式组成
        priority 增加 "-" 号是因为 heappush 默认是最小堆
        index 是为了当两个对象的优先级一致时，按照插入顺序排列
        """
        heapq.heappush(self._queue, (-priority, self._index, item))
        self._index += 1

    def pop(self):
        """
        弹出优先级最高的对象
        """
        return heapq.heappop(self._queue)[-1]

    def qsize(self):
        return len(self._queue)

    def empty(self):
        return True if not self._queue else False
"""
    4. 数据输入
"""
with open(config_file, 'r', encoding='utf-8') as f:
    qos = int(re.search(r'qos_constraint=(\d+)', f.read()).group(1))

with open(demand_file, 'r', encoding='utf-8') as f:
    lines = csv.reader(f)
    header = next(lines)[1:]
    demand = np.array([[int(item) for item in line[1:]] for line in lines])
    client_node2id = {}
    for idx, item in enumerate(header):
        client_node2id[item] = idx

with open(qos_file, 'r', encoding='utf-8') as f:
    lines = csv.reader(f)
    header = next(lines)[1:]
    edge_node2id = {}
    lines = list(lines)
    num_edge, num_client = len(lines), len(header)
    edge_client_qos = np.zeros((num_edge, num_client), dtype=np.int32)
    idx_tran_tmp = {}
    for idx, item in enumerate(header):
        idx_tran_tmp[idx] = client_node2id[item]
    for idx, line in enumerate(lines):
        edge_node2id[line[0]] = idx
        for idx2, value in enumerate(line[1:]):
            edge_client_qos[idx, idx_tran_tmp[idx2]] = int(value)

with open(site_bandwidth_file, 'r', encoding='utf-8') as f:
    site_bandwidth = np.zeros((num_edge), dtype=np.int32)
    lines = csv.reader(f)
    next(lines)
    for line in lines:
        site_bandwidth[edge_node2id[line[0]]] = int(line[1])

# qos : int 阈值
# demand : np.Array 需求 [时刻数 * 客户数]
# edge_client_qos : np.Array 时延 [边缘节点 * 客户节点]
# site_bandwidth : np.Array 带宽限制 [边缘节点]
# client_node2id : dic str->idx
# edge_node2id : dic str->idx

"""
    4. 选择top4的节点无限使用
"""
T,C = demand.shape
E, = site_bandwidth.shape
result = np.zeros((T,E,C),dtype=int)
eps = 1e-8

mask = edge_client_qos < qos
# mask矩阵 mask[e,c]为True表示边缘节点e可以分配给用户节点c

time_demand = demand.sum(axis = 1).tolist()
pq = My_PriorityQueue()
for idx,value in enumerate(time_demand):
    pq.push(idx,value)

# time_demand 每个时间点所需要的流量
need_edges = mask.sum(axis = 1) > 0 # [E]
num_set = 0
edge_left = np.array([int(T*0.05) for i in range(E)]) * need_edges
customer_priority = mask.sum(axis = 0).tolist() # [C]
sorted_priority = [(customer_id,customer_value)for customer_id,customer_value in enumerate(customer_priority)]
sorted_priority.sort(key=lambda x:x[1])
sorted_priority = [customer_id for customer_id,customer_value in sorted_priority]

# 这里没有完全利用
sum_need_edges = np.sum(need_edges)
while not pq.empty() and num_set < sum_need_edges * (int(T*0.05)):
    max_t = pq.pop()
    demand_t = demand[max_t] # [C]
    mask_d = mask
    # edge_t = np.matmul(mask_d,demand_t) # 每个边缘节点的加权值
    # edge_t_list = [(edge_id,edge_value) for edge_id,edge_value in enumerate(edge_t)]
    # edge_t_list.sort(key=lambda x:x[1],reverse=True)
    edge_t_list = list(range(E))
    random.shuffle(edge_t_list)
    # 针对t时间
    select = False
    for edge_id in edge_t_list:
        if edge_left[edge_id] == 0 or result[max_t,edge_id].sum()>0:
            continue
        # 按照优先级选择客户节点，优先选择可分配用户少的
        cur_value = 0
        for customer_id in sorted_priority:
            if not mask[edge_id,customer_id]:
                continue
            tran = min(demand[max_t,customer_id],site_bandwidth[edge_id]-cur_value)
            cur_value += tran
            demand[max_t,customer_id] -= tran
            result[max_t,edge_id,customer_id] += tran
            if cur_value == site_bandwidth[edge_id]:
                break
        if cur_value>0:
            edge_left[edge_id] -= 1
            select = True
            break
    if select: # 尽最大可能选择完毕
        pq.push(max_t,demand[max_t].sum())
        num_set += 1

top4 = result.sum(axis=2)
top4_mask = top4 == 0

"""
    5. 模型搭建
"""
class Model(object):
    def __init__(self,mask,demand,site_bandwidth,top4_mask,config):
        self.T,_,self.C = demand.shape
        _,self.E = site_bandwidth.shape
        # self.X = np.random.normal(size=(self.T,self.E,self.C))*4 # [ T, E, C] norm_dim = 1
        self.X = np.zeros((self.T,self.E,self.C)) # [ T, E, C] norm_dim = 1
        self.mask = mask # [1, E, C]
        self.top4_mask = top4_mask # [T,E,1]
        self.demand = demand # [T,1,C]
        self.site_bandwidth = site_bandwidth # [1,E]
        self.best_X = np.zeros((self.T,self.E,self.C)) # [ T, E, C] norm_dim = 1
        self.min_loss = -1

        self.config = config
        self.X_m = np.zeros((self.T,self.E,self.C))
        self.X_v = np.zeros((self.T,self.E,self.C))
        self.X_m_hat = np.zeros((self.T,self.E,self.C))
        self.X_v_hat = np.zeros((self.T,self.E,self.C))

        X_tmp = np.exp(self.X - np.max(self.X,axis=1,keepdims=True))
        self.X_norm = X_tmp / (np.sum(X_tmp,axis=1,keepdims=True) + self.config.eps)
        self.X_masked = self.X_norm * self.mask * self.top4_mask
        self.X_prob = self.X_masked / (np.sum(self.X_masked,axis=1,keepdims=True)+self.config.eps) # [ T, E, C]
        self.edge_use = (self.X_prob * self.demand).sum(axis=2)                  # [ T, E]

    # time 为挑选次数
    # 选择top_time
    def select_top(self,time):
        if time == 0:
            tmp = self.demand.sum(axis = 1)
            self.top_time = heapq.nlargest(max(int(self.T*0.2),1),range(len(tmp)),tmp.take) # [C]
            self.top_demand = self.demand[self.top_time]
            self.top_4mask = self.top4_mask[self.top_time]
        else:
            self.edge_use[self.top_time] = self.top_edge_use                  # [ T, E]
            self.top_time = list(set(self.edge_use.argmax(axis=0))) # [E]    
            self.top_demand = self.demand[self.top_time]
            self.top_4mask = self.top4_mask[self.top_time]
        # self.edge_use_mean = self.edge_use.sum(axis = 0,keepdims=True) / (int(0.95*self.T) + self.config.eps)
        self.edge_max = self.edge_use.max(axis = 0,keepdims=True) # 
    
    def forward(self):
        self.top_x = self.X[self.top_time]
        X_tmp = np.exp(self.top_x - np.max(self.top_x,axis = 1,keepdims = True))
        self.top_norm = X_tmp / (np.sum(X_tmp,axis = 1,keepdims=True)+self.config.eps)
        self.top_masked = self.top_norm * self.mask * self.top_4mask
        self.top_prob = self.top_masked / (np.sum(self.top_masked,axis=1,keepdims=True) + self.config.eps)
        self.top_edge_use = (self.top_prob * self.top_demand).sum(axis = 2)
        # self.top_edge_use_mean = self.top_edge_use.mean(axis = 0,keepdims=True)

        self.loss = np.sum(np.max(self.top_edge_use,axis=0))
        self.idx = np.argmax(self.top_edge_use,axis=0)
        # print(self.loss)
        # if self.min_loss == -1 or self.loss < self.min_loss:
        #     self.min_loss = self.loss
        #     self.best_X = self.X.copy()
        print('batch',self.loss)

    def backward(self):
        self.top_T = len(self.top_time)
        tmp = self.top_edge_use - self.site_bandwidth
        self.d_top_edge_use = np.zeros((self.top_T,self.E))
        self.d_top_edge_use[self.idx,np.arange(self.E)] = 1 # 最大值损失
        self.d_top_edge_use += 1 * (self.top_edge_use - self.edge_max) + \
                               100 * tmp * (tmp>0) # 超出限制损失
        self.d_top_prob = np.expand_dims(self.d_top_edge_use,axis = 2) * self.top_demand
        top_sum_x = np.sum(self.top_masked,axis=1,keepdims=True)
        self.d_top_masked = 1/(top_sum_x + self.config.eps) * (self.d_top_prob - np.sum(self.d_top_prob*self.top_prob,axis=1,keepdims=True))
        self.d_top_norm = self.d_top_masked * self.mask * self.top_4mask
        self.d_top_x = self.top_norm * (self.d_top_norm - np.sum(self.d_top_norm*self.top_norm,axis=1,keepdims=True))

        # 更新参数
    def update(self,t):
        # 更新参数 adam优化器
        self.X_m[self.top_time] = self.config.beta1*self.X_m[self.top_time]+(1-self.config.beta1)*self.d_top_x
        self.X_v[self.top_time] = self.config.beta2*self.X_v[self.top_time]+(1-self.config.beta2)*self.d_top_x*self.d_top_x
        self.X_m_hat[self.top_time] = self.X_m[self.top_time] / (1-np.power(self.config.beta1,t+1))
        self.X_v_hat[self.top_time] = self.X_v[self.top_time] / (1-np.power(self.config.beta2,t+1))

        update_X = self.config.lr*self.X_m_hat[self.top_time]/(np.sqrt(self.X_v_hat[self.top_time])+self.config.eps)

        self.X[self.top_time] = self.X[self.top_time] - update_X
        # self.X[self.top_time] = self.X[self.top_time] - self.d_top_x

    def train(self):
        for time in range(self.config.train_num):
            self.select_top(time)
            if time % 10 == 0:
                self.config.lr *= 0.9
            if time == 0 :
                num_epoch = self.config.first_num_epoch
            else:
                num_epoch = self.config.num_epoch
            for epoch in range(num_epoch):
                self.forward()
                self.backward()
                self.update(time*self.config.num_epoch*epoch)
            # self.caculate_cost()

    def caculate_cost(self):
        X_tmp = np.exp(self.X - np.max(self.X,axis=1,keepdims=True))
        self.X_norm = X_tmp / (np.sum(X_tmp,axis=1,keepdims=True) + self.config.eps)
        self.X_masked = self.X_norm * self.mask * self.top4_mask
        self.X_prob = self.X_masked / (np.sum(self.X_masked,axis=1,keepdims=True) + self.config.eps) # [ T, E, C]
        self.edge_use = (self.X_prob * self.demand).sum(axis=2)     # [ T, E]
        # self.edge_mean = self.edge_use.sum(axis=0)/(int(T*0.95))
        # self.var = np.power(self.edge_use-self.edge_mean,2)
        # self.loss = np.sum(self.var) # 真正损失为最小化方差
        # 计算花费b2
        self.loss = np.sum(np.max(self.edge_use,axis=0))
        print('all_loss',self.loss)

    def decode(self):
        X_tmp = np.exp(self.X - np.max(self.X,axis=1,keepdims=True))
        self.X_norm = X_tmp / (np.sum(X_tmp,axis=1,keepdims=True) + self.config.eps)
        self.X_masked = self.X_norm * self.mask * self.top4_mask
        self.X_prob = self.X_masked / (np.sum(self.X_masked,axis=1,keepdims=True) + self.config.eps) # [ T, E, C]
        self.d_answer = self.X_prob * self.demand
        self.edge_use = self.d_answer.sum(axis=2)     # [ T, E]
        self.edge_topk  = np.max(self.edge_use,axis=0)
        self.answer = np.zeros((self.T,self.E,self.C),dtype=int)
        self.site_bandwidth = self.site_bandwidth.squeeze()
        self.demand = self.demand.squeeze()
        self.mask = self.mask.squeeze()
        self.top4_mask = self.top4_mask.squeeze()

        for t in range(self.T):
            """小数部分"""
            for c in range(self.C):
                for e in range(self.E):
                    self.answer[t,e,c] = int(self.d_answer[t,e,c])
                need_to_distribution = self.demand[t,c] - np.sum(self.answer[t,:,c])
                e = 0
                while need_to_distribution > 0 and e < self.E:
                    if np.sum(self.answer[t,e,:]) < int(self.edge_topk[e]) and self.mask[e,c] and self.top4_mask[t,e]:
                        distribution = min(need_to_distribution,int(self.edge_topk[e]) - np.sum(self.answer[t,e,:]))
                        # assert distribution >= 0
                        self.answer[t, e, c] += distribution
                        need_to_distribution -= distribution
                    elif np.sum(self.answer[t,e,:]) > int(self.edge_topk[e]) and self.mask[e,c] and self.top4_mask[t,e]\
                            and need_to_distribution>0 and np.sum(self.answer[t,e,:])<self.site_bandwidth[e] :
                        distribution = min(need_to_distribution,self.site_bandwidth[e] - np.sum(self.answer[t,e,:]) )
                        # assert distribution >= 0
                        self.answer[t, e, c] += distribution
                        need_to_distribution -= distribution
                    e += 1
                if e == self.E and need_to_distribution>0:
                    e = 0
                    while need_to_distribution > 0 and e < self.E:
                        if np.sum(self.answer[t,e,:]) == int(self.edge_topk[e]) and self.mask[e,c] and self.top4_mask[t,e]\
                            and need_to_distribution > 0 and np.sum(self.answer[t, e, :]) < self.site_bandwidth[e]:
                            distribution = min(need_to_distribution, self.site_bandwidth[e] - np.sum(self.answer[t,e,:]))
                            # assert distribution > 0
                            self.answer[t, e, c] += distribution
                            need_to_distribution -= distribution
                        e += 1
                # assert need_to_distribution == 0
            """超出部分"""
            for e in range(self.E):
                # 没超过边缘节点阈值
                if np.sum(self.answer[t,e,:]) <= self.site_bandwidth[e]:
                    continue
                need = np.sum(self.answer[t,e,:]) - self.site_bandwidth[e]
                # 超过边缘节点阈值
                for c in range(self.C):
                    # e 对应的 c节点分配给其他的e
                    for e2 in range(self.E):
                        if not self.mask[e2, c] or e == e2 or np.sum(self.answer[t,e2,:]) >= self.site_bandwidth[e2] or not self.top4_mask[t,e2]:
                            continue
                        # 需要多少 和 能提供多少 取最小值
                        distribution = min(need,self.site_bandwidth[e2] - np.sum(self.answer[t,e2,:]),self.answer[t,e,c])
                        # assert distribution > 0
                        need -= distribution
                        self.answer[t,e,c] -= distribution
                        self.answer[t,e2,c] += distribution
                        if need == 0:
                            break
                    if need == 0:
                        break
                # assert need == 0

    def output_file(self,edge_node2id,client_node2id):
        output_list = []
        cost = self.answer.sum(axis = 2).max(axis=0).sum()
        print('cost final:',cost) # 损失一边降一边选
        self.answer = self.answer + result # 加上第一步的结果
        for t in range(self.T):
            for client_name,client_id in client_node2id.items():
                string = f'{client_name}:'
                first = True
                for edge_name,edge_id in edge_node2id.items():
                    if self.answer[t,edge_id,client_id] >0:
                        if not first:
                            string += ','
                        string += f'<{edge_name},{self.answer[t,edge_id,client_id]}>'
                        first = False
                output_list.append(string)
        with open(output_file,'w',encoding='utf-8') as f:
            f.write('\n'.join(output_list))
"""
    主函数
"""
if __name__ == '__main__':
    start = time.time()
    config = Config()
    m = Model(np.expand_dims(mask,0),np.expand_dims(demand,1),np.expand_dims(site_bandwidth,0),np.expand_dims(top4_mask,2),config)
    m.train()
    m.decode()
    m.output_file(edge_node2id,client_node2id)

    